import numpy as np
from scipy import sparse
#from variables import (kus, kys, kvs, one_m, one_n, one_l, V0, Vy, Omega0, t0, tSO, ty, mp1_m_matrix, np1_n_matrix, lp1_l_matrix)
from variables import (kus, kys, kvs, one_m, one_n, one_l, t0, tSO, ty, mp1_m_matrix, np1_n_matrix, lp1_l_matrix)

# The sign convention of this Hamiltonian follows Sun's scheme notes.
def Hp(ku, ky, kv):
    """Build the two quadratic-in-momentum operators on the (m, n, l) product basis.

    Each operator is the sum of three Kronecker products, one per axis, where
    the axis factor is ``np.diag`` of the squared shifted momenta and the
    other two factors are ``one_m`` / ``one_n`` / ``one_l``.  The u and v
    contributions are divided by 2; the y contribution is not.

    Parameters
    ----------
    ku, ky, kv
        Momentum offsets added to ``2*kus``, ``2*kys`` and ``2*kvs``.
        NOTE(review): ``kus``/``kys``/``kvs`` are presumably 1-D arrays of
        basis indices and ``one_*`` identity matrices -- confirm in
        ``variables``.

    Returns
    -------
    tuple
        ``(hp1, hp4)``: ``hp1`` uses the momenta ``k + 2*ks`` and ``hp4`` the
        momenta ``k + 2*ks + 1`` on every axis.
    """
    def kron3(a, b, c):
        # Same nesting as kron(kron(a, b), c) so the basis ordering is kept.
        return sparse.kron(sparse.kron(a, b), c)

    def quadratic(u_vals, y_vals, v_vals):
        u_part = kron3(np.diag(np.power(u_vals, 2)), one_n, one_l)/2
        # The y term intentionally carries no 1/2 factor, as in the formulae
        # this function implements.
        y_part = kron3(one_m, np.diag(np.power(y_vals, 2)), one_l)
        v_part = kron3(one_m, one_n, np.diag(np.power(v_vals, 2)))/2
        return u_part+y_part+v_part

    hp1 = quadratic(ku+2*kus, ky+2*kys, kv+2*kvs)
    hp4 = quadratic(ku+2*kus+1, ky+2*kys+1, kv+2*kvs+1)

    return hp1, hp4

def Hspin(mz, t1, t2, t3, t4):
    """Assemble the four lattice blocks built from the shift matrices.

    Every term is a scalar coefficient times
    ``kron(kron(A_m, A_n), A_l)`` where each factor is either the matching
    ``one_*`` matrix or a ``*p1_*_matrix`` shift matrix (possibly transposed).

    Parameters
    ----------
    mz
        Enters the on-site coefficient with ``+mz`` in ``hu2_3`` and ``-mz``
        in ``hd2_3``.
    t1, t2
        Coefficients of the single-shift terms: ``t1/4`` multiplies the m and
        l shifts, ``t2/4`` the n shifts; ``t2/2 + t1`` enters the on-site
        coefficient.
    t3, t4
        Coefficients of the eight-term groups in ``hu4_5`` / ``hd4_5``.

    Returns
    -------
    tuple
        ``(hu2_3, hd4_5, hu4_5, hd2_3)`` -- note the order.
        NOTE(review): the ``hu``/``hd`` prefixes presumably denote spin-up and
        spin-down rows of a 2x2 spin structure -- confirm against the caller.
    """
    def kron3(a, b, c):
        # Same nesting as kron(kron(a, b), c) so the basis ordering is kept.
        return sparse.kron(sparse.kron(a, b), c)

    def add_all(terms):
        # Plain left-to-right accumulation, matching a chained ``a+b+c+...``.
        total = terms[0]
        for term in terms[1:]:
            total = total + term
        return total

    identity = kron3(one_m, one_n, one_l)

    # Single-shift terms; identical for hu2_3 and hd2_3.
    hopping = [
        t1/4*kron3(mp1_m_matrix, one_n, one_l),
        t1/4*kron3(mp1_m_matrix.T, one_n, one_l),
        t1/4*kron3(one_m, one_n, lp1_l_matrix),
        t1/4*kron3(one_m, one_n, lp1_l_matrix.T),
        t2/4*kron3(one_m, np1_n_matrix, one_l),
        t2/4*kron3(one_m, np1_n_matrix.T, one_l),
    ]

    def mixed_block(m_shift, n_shift, l_shift, t3_plain, t3_shifted):
        # Runs over the 8 combinations of {one, shift} on each axis, with the
        # m axis varying fastest and the l axis slowest.  The t3 coefficient
        # depends on whether the m factor is the shift; the t4 sign depends
        # on whether the l factor is the shift.
        t3_terms = []
        t4_terms = []
        for l_moved, l_op in enumerate((one_l, l_shift)):
            for n_op in (one_n, n_shift):
                for m_moved, m_op in enumerate((one_m, m_shift)):
                    product = kron3(m_op, n_op, l_op)
                    t3_terms.append((t3_shifted if m_moved else t3_plain)*product)
                    t4_terms.append((t4/8 if l_moved else -t4/8)*product)
        return add_all(t3_terms+t4_terms)

    hu2_3 = add_all([(t2/2+t1+mz)*identity]+hopping)
    hd2_3 = add_all([(t2/2+t1-mz)*identity]+hopping)

    # ``t3/8j`` divides by the imaginary literal 8j, i.e. it equals -1j*t3/8.
    # hu4_5 uses the transposed shifts; hd4_5 uses the untransposed shifts
    # with the opposite t3 sign.
    hu4_5 = mixed_block(mp1_m_matrix.T, np1_n_matrix.T, lp1_l_matrix.T,
                        t3/8j, -t3/8j)
    hd4_5 = mixed_block(mp1_m_matrix, np1_n_matrix, lp1_l_matrix,
                        -t3/8j, t3/8j)

    return hu2_3, hd4_5, hu4_5, hd2_3

def H_TBM(ku, ky, kv, mz):
    """Build three operators from sines and cosines of the shifted momenta.

    Parameters
    ----------
    ku, ky, kv
        Momentum offsets added to ``kus``, ``kys`` and ``kvs`` respectively.
    mz
        Coefficient of the ``one_m (x) one_n (x) one_l`` term in ``hsgz``.

    Returns
    -------
    tuple
        ``(hsgx, hsgy, hsgz)`` where

        * ``hsgx = 2*tSO*sin(ku+kus)`` placed on the m axis,
        * ``hsgy = 2*tSO*sin(kv+kvs)`` placed on the l axis,
        * ``hsgz = mz - 2*t0*cos(ku+kus) - 2*t0*cos(kv+kvs) - 2*ty*cos(ky+kys)``
          with each cosine placed on its own axis.

        NOTE(review): the names suggest these multiply sigma_x, sigma_y and
        sigma_z in the caller -- confirm.
    """
    def kron3(a, b, c):
        # Same nesting as kron(kron(a, b), c) so the basis ordering is kept.
        return sparse.kron(sparse.kron(a, b), c)

    sin_u = np.diag(np.sin(ku+kus))
    sin_v = np.diag(np.sin(kv+kvs))
    cos_u = np.diag(np.cos(ku+kus))
    cos_v = np.diag(np.cos(kv+kvs))
    cos_y = np.diag(np.cos(ky+kys))

    hsgx = 2*tSO*kron3(sin_u, one_n, one_l)
    hsgy = 2*tSO*kron3(one_m, one_n, sin_v)

    mass_term = mz*kron3(one_m, one_n, one_l)
    u_term = -2*t0*kron3(cos_u, one_n, one_l)
    v_term = -2*t0*kron3(one_m, one_n, cos_v)
    y_term = -2*ty*kron3(one_m, cos_y, one_l)
    hsgz = mass_term+u_term+v_term+y_term

    return hsgx, hsgy, hsgz

    